kv-cache : refactor + add llama_memory_state_i #13746
Conversation
Force-pushed from d23f887 to 8323e23
Force-pushed from c1434b8 to 1eec34a
This PR should not cause any performance changes and the numerical results should be mostly the same (with some small exceptions due to the new logic in `find_slot()`). Would appreciate some testing and reports for regressions. Thanks.
I re-ran the ppl test from #13194 (comment). master at aa50ba4:
This PR:
Some results changed very slightly, so I'm not sure if this is expected.
Yes, I think this difference is expected for SWA models (note: Phi currently has SWA disabled, so no difference there). It's caused by the different order in which we place the data in memory, due to the new `find_slot()` logic.
Yes, that's right, I added

Edit: except for
I re-ran the test and the ppl stays the same as in my last comment. Btw, just thinking: is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?
The command:

```sh
./bin/llama-perplexity -hf bartowski/gemma-2-9b-it-GGUF:Q4_K_M -f ./wikitext-2-raw/wiki.test.raw -c 16384 -fa --chunks 2 --swa-full
```

Maybe the issue is with your reference value on `master`?
Can you clarify?
I can't run the ppl right now, but if you get the correct result, then yes, I think it could be a problem on my side.
Currently, AFAIU, the ppl test simply evaluates the text chunk by chunk, but only going forward. For example, if I have 3 chunks 1-2-3, then they will be evaluated in the order 1-2-3. But what we also want to test is, for example, removing a chunk from the KV cache and then evaluating it again.

So I expect the ppl to be the same as just doing 1-2-3.
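Something like the following sketch could illustrate the proposed mode. It is only a sketch: it assumes the public `llama_kv_self_seq_rm()` API, and the `eval_chunk` callback as well as the `test_kv_remove` helper are hypothetical stand-ins for the existing perplexity loop (decode one chunk into sequence 0 and return its NLL sum):

```cpp
#include "llama.h"

#include <cmath>
#include <functional>

// sketch of a possible KV-remove test mode for llama-perplexity:
// evaluate the chunks forward (1-2-3), then remove the last chunk from the
// KV cache and evaluate it again - the resulting ppl should be identical
static void test_kv_remove(llama_context * ctx, int n_chunk_tokens,
                           const std::function<double(int first, int n_tokens)> & eval_chunk) {
    eval_chunk(0 * n_chunk_tokens, n_chunk_tokens); // chunk 1
    eval_chunk(1 * n_chunk_tokens, n_chunk_tokens); // chunk 2

    const double nll_3 = eval_chunk(2 * n_chunk_tokens, n_chunk_tokens); // chunk 3

    // drop the cells of chunk 3 (sequence 0) via the KV remove API ...
    llama_kv_self_seq_rm(ctx, 0, 2 * n_chunk_tokens, 3 * n_chunk_tokens);

    // ... and evaluate chunk 3 again - same context, so the same NLL is expected
    const double nll_3_again = eval_chunk(2 * n_chunk_tokens, n_chunk_tokens);

    GGML_ASSERT(std::fabs(nll_3 - nll_3_again) < 1e-3);
}
```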
Force-pushed from 3ef770f to 0b73da5
How does this recover from a failed call to `llama_decode()`?
There are some tricky scenarios in which we could have overwritten some of the data in the cache by the time the error occurs (i.e. we have processed the first few ubatches, but not all of them yet). Before (i.e. on `master`), the `llama_kv_cache_guard` commit/restore mechanism was used for this, which is removed in this PR.

I think that on a compute error, the KV cache should be assumed to be in an undefined state and the application should take the necessary steps to recover (i.e. by clearing it and reprocessing the context that is currently needed). Later on, this reprocessing will become seamless, when we start storing the necessary tokens/embeddings information and add the logic for auto-reprocessing whatever is currently missing from the cache.
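In application terms, the recovery described here could look roughly like the sketch below. It assumes the public `llama_kv_self_clear()` API; the `reprocess` callback and the wrapper itself are hypothetical stand-ins for re-submitting whatever context the application still needs:

```cpp
#include "llama.h"

// decode with the recovery strategy described above: on error, treat the
// KV cache contents as undefined, clear the cache and rebuild the context
static int32_t decode_with_recovery(llama_context * ctx, llama_batch batch,
                                    void (*reprocess)(llama_context *)) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret != 0) {
        // parts of the cache may already have been overwritten by the ubatches
        // that were processed before the failure - start from a clean state
        llama_kv_self_clear(ctx);
        reprocess(ctx);
    }
    return ret;
}
```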
I am mostly concerned about the abort callback functionality. Errors in the backend are likely to be unrecoverable, but I am not sure if the abort functionality makes sense if it leaves the cache in a bad state.
I admit that I had completely forgotten about the abort callback. Let me see if we can do something about this.
Force-pushed from 0b73da5 to 2252eef
Drafting for now as I want to do some more testing and think about the abort mechanism.
The code under review:

```cpp
const llama_seq_id seq_id = ubatch.seq_id[i][0];

// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
//   - mask causally, if the sequence is the same as the one we are inserting
//   - mask SWA, using current max pos for that sequence in the cache
//     always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i);

if (!can_use && cells.seq_count(head_cur + i) == 1) {
```
that would automatically disqualify all of the other logic around reusing full cells?
Assuming this is correct, I think this would be the correct approach?
Suggested change:

```diff
-const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
 // can we use this cell? either:
 // - the cell is empty
 // - the cell is occupied only by one sequence:
 //   - mask causally, if the sequence is the same as the one we are inserting
 //   - mask SWA, using current max pos for that sequence in the cache
 //     always insert in the cell with minimum pos
 bool can_use = cells.is_empty(head_cur + i);

-if (!can_use && cells.seq_count(head_cur + i) == 1) {
+if (!can_use && cells.seq_count(head_cur + i) == 1 && ubatch.n_seqs == 1) {
+    const llama_seq_id seq_id = ubatch.seq_id[0][0];
```
That diff is gross, but it just adds an extra conditional to the outer check that checks whether `ubatch.n_seqs == 1`, and then always uses `ubatch.seq_id[0][0]`.
@gabe-l-hart It should not be necessary to limit this branch to when `ubatch.n_seqs` is equal to `1`. This almost never happens for simple splits anyway, except when `n_ubatch` is `1`.

See #13746 (comment).
Race condition! Thanks thanks
I think that we should simplify `llama_ubatch` by enforcing that the tokens in the batch belong to only one sequence. The use-case for multiple sequences per input token is very rare and can trivially be achieved with `llama_kv_self_seq_cp()` if needed. Hence I added these TODOs:

Lines 21 to 22 in 9548d2a:

```cpp
int32_t      *  n_seq_id; // [n_seqs]   // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id;   // [n_seqs]   // TODO: become llama_seq_id * seq_id;
```
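A minimal sketch of the suggested workaround, assuming the `common_batch_add()` helper from `common.h` and the public `llama_kv_self_seq_cp()` API (the `share_token` wrapper itself is illustrative):

```cpp
#include "llama.h"
#include "common.h"

// instead of assigning one input token to several sequences at once
// (common_batch_add(batch, token, pos, { 0, 1 }, true);), decode it into a
// single sequence and copy the resulting KV data to the other sequence
static void share_token(llama_context * ctx, llama_batch & batch,
                        llama_token token, llama_pos pos) {
    common_batch_add(batch, token, pos, { 0 }, true);

    if (llama_decode(ctx, batch) == 0) {
        // copy the cells in [pos, pos + 1) from sequence 0 to sequence 1
        llama_kv_self_seq_cp(ctx, 0, 1, pos, pos + 1);
    }
}
```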
Thanks! Simplifying the multi-sequence-per-token logic would certainly help from a clarity perspective (having recently tried to understand the current implementation and been only partially successful).
> The use-case for multiple sequences per input token is very rare

If the public `llama_batch` API is changed (or at least `common_batch_add`), I think the only places where multiple sequences per input token are used are in tools/perplexity/perplexity.cpp, for the `hellaswag`, `winogrande` and `multiple-choice` benchmarks (I don't know if any 3rd-party project uses that feature, though).

I can say this will simplify part of the recurrent cache's `find_slot` (since it did attempt to handle multi-sequence tokens, at least enough to make `hellaswag` run properly).

Another reason to remove this is that it is not obvious what should happen when using multiple sequences for a new token when the sequences have already diverged. The current behavior is to use the first `seq_id` and overwrite the states of the other sequences that are part of that token (at least in the recurrent cache). I'm not sure how that case is handled for the unified cache, but this case is very hard to handle correctly (I'm not even sure what the correct behavior should be here). (This case doesn't really happen in practice, though, since multiple sequences per input token are very rarely used, and also not in this way. But the problem is that they could be, and it leads to confusing behavior.)
Force-pushed from 825efad to eed741e
@slaren In eed741e I think I managed to extract the memory state into the new `llama_memory_state_i` interface.

We no longer pass the memory object when building the compute graphs. Instead, we prepare a memory state for each ubatch and pass this state to the graph building context. The memory state carries the necessary information for building the graph of the current ubatch.

I was also able to elegantly replace the `llama_kv_cache_guard` mechanism this way.

Sorry for the large diff again. Let me know if you have any follow-up comments or suggestions.
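As a rough mental model of the pattern described here (the member names below are illustrative, not the exact interface added by this PR), the memory prepares one state per ubatch and the graph building code only ever sees that state:

```cpp
#include <memory>

// simplified stand-ins for the real llama.cpp types
struct llama_batch_stub  { /* the full input batch (stand-in)  */ };
struct llama_ubatch      { /* tokens, positions, seq ids, ...  */ };
struct llm_graph_context { /* the context used to build graphs */ };

// per-ubatch view of the memory, passed to the graph building context
struct llama_memory_state_i {
    virtual ~llama_memory_state_i() = default;

    virtual const llama_ubatch & get_ubatch() const = 0;        // the ubatch this state was prepared for
    virtual bool next() = 0;                                    // advance to the next ubatch, if any
    virtual void set_inputs(llm_graph_context & ctx) const = 0; // provide the KV-related graph inputs
};

// the memory object itself is only used to *prepare* states - it is no longer
// passed around while building the compute graphs
struct llama_memory_i {
    virtual ~llama_memory_i() = default;

    // returns nullptr if the batch cannot be fit into the memory
    virtual std::unique_ptr<llama_memory_state_i> init_batch(const llama_batch_stub & batch) = 0;
};
```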
This looks good to me. Some notes for the future:
- `init_batch` and `init_full` could be moved to `llama_memory_i`. There are still many places where `llama_kv_cache` is used directly instead of `llama_memory_i`. `llama_kv_cache_recurrent` probably should not inherit from `llama_kv_cache`, but rather it should be a completely separate implementation of `llama_memory_i`. I believe that's already the plan, and `llama_kv_cache` is only used in this way to simplify the transition, but ultimately functions like `llama_decode` should not depend on `llama_kv_cache`.
- `llama_kv_cache` could probably be renamed to `llama_kv_cache_i` for consistency.
- From what I can tell, `llama_kv_cache_unified_state_i`, `llama_kv_cache_unified_iswa_state_i` and `llama_kv_cache_recurrent_state_i` do not need to be interfaces. Is the goal to hide the implementation details from the header? Since virtual functions have a nonzero performance cost, I would be wary about turning everything into an interface and adding an indirection to every function call, even if it is not likely to have a significant performance impact at the moment.
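For comparison, a tiny sketch of the non-interface alternative (hypothetical names, not code from this PR): a concrete, final class with private members hides the implementation details just as well, without virtual dispatch on every accessor:

```cpp
#include <cstdint>
#include <vector>

// concrete state object: no abstract base class, no virtual calls
class llama_kv_cache_unified_state final {
public:
    explicit llama_kv_cache_unified_state(uint32_t head) : head(head) {}

    // non-virtual accessor - trivially inlined by the compiler
    uint32_t get_head() const { return head; }

private:
    // implementation details are still hidden from the callers
    uint32_t head = 0;
    std::vector<uint32_t> cells;
};
```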
Force-pushed from 9d05381 to 2b984f4
Updated - these are no longer interfaces.
Yes, the goal is to completely migrate to the `llama_memory_i` abstraction.

I set the PR ready for review and am planning to merge this soon, unless we spot any regressions. The next short-term steps will be handled in follow-up PRs.

At this point, I think we should focus on refactoring the `llama-batch` code.
Force-pushed from f23e4cc to 71619f2
cont #13706 (comment), #13194
Main goal here is to simplify the abstract interface of `struct llama_kv_cache`.

Overview
Changes to the internal `struct llama_kv_cache` abstract interface:

- `llama_kv_cache::commit()`
- `llama_kv_cache::restore()`
- `llama_kv_cache::sbatch_init()`
- `llama_kv_cache::ubatch_next()`
- `llama_kv_cache::find_slot()`
- `llama_kv_cache_guard`
This new interface changes the logic in `llama_decode()` to first make sure that we can fit the input batch into the cache, and only after that do we start to process the ubatches. This check correctly takes into account SWA masking and also makes sure that the cache will not be modified before we start the actual computation.

Note: the latter is not yet true for the recurrent cache - see comments in the code.
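Reusing the simplified stand-ins from the sketch earlier in this discussion, the reworked flow can be pictured roughly as follows (illustrative only, not the actual `llama_decode()` code):

```cpp
// sketch only - reuses llama_memory_i / llama_memory_state_i / llama_batch_stub
// from the simplified sketch above
static int32_t decode_sketch(llama_memory_i & mem, const llama_batch_stub & batch) {
    // 1. check that the whole batch (all ubatches) fits, taking SWA masking
    //    into account - the cache contents are not modified by this step
    auto mstate = mem.init_batch(batch);
    if (!mstate) {
        return 1; // cannot fit - nothing was written to the cache
    }

    // 2. only now start processing the ubatches and updating the cache
    do {
        // build the graph for mstate->get_ubatch(), set inputs, compute, ...
    } while (mstate->next());

    return 0;
}
```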
Another important update in this PR is that the `find_slot()` logic for unified caches is now improved. Before, we looked for a slot (i.e. a set of contiguous cells) that is empty in order to place the ubatch in it. We now allow the slot to contain data from the same or another sequence which is masked (either by causality or by SWA):

llama.cpp/src/llama-kv-cache.cpp, lines 574 to 621 in 2252eef
This change is needed for the next PR, which will optimize the SWA cache to use just `n_swa + n_ubatch` cells, and it also has some other nice properties. For example, we no longer have to explicitly prune tokens on successful batch processing, which simplifies the logic significantly and allows us to re-enable speculative decoding for SWA models (this will also be done in the next PR).

The worst-case graph reserve logic is also refactored and simplified significantly.
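To make the reuse rule concrete, here is a condensed and simplified version of the check, not the exact code from the linked `llama-kv-cache.cpp` range: a non-empty cell may be overwritten if its contents are already masked for the incoming token, either causally (same sequence, position the new token never attends to) or by the sliding window:

```cpp
#include <cstdint>

typedef int32_t llama_pos;
typedef int32_t llama_seq_id;

// simplified cell-reuse check - the real logic lives in find_slot()
//   new_seq/new_pos   : the token being placed
//   cell_seq/cell_pos : the single sequence currently occupying the cell
//   cell_seq_pos_max  : current max pos of cell_seq in the cache
//   n_swa             : sliding window size (0 = no SWA)
static bool can_reuse_cell(llama_seq_id cell_seq, llama_pos cell_pos, llama_pos cell_seq_pos_max,
                           llama_seq_id new_seq,  llama_pos new_pos,  uint32_t n_swa) {
    // causally masked: same sequence, and the new token would never attend to this position
    const bool masked_causal = cell_seq == new_seq && cell_pos >= new_pos;

    // SWA masked: the position has already slid out of the window of its own sequence
    const bool masked_swa = n_swa > 0 && cell_pos + (llama_pos) n_swa < cell_seq_pos_max;

    return masked_causal || masked_swa;
}
```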
There are also some changes to `llama-batch`, but these are mainly to patch things up so that we are able to push the KV cache refactor first. So no need to review the `llama-batch` changes in deep detail - the code there will be reworked soon.

TODO
Next PRs

- Prepare batches automatically inside `llama_decode`, so that user code does not have to do it (llama : auto-batch preparation #13845)
- Use `n_swa + n_ubatch` cells for the SWA cache (llama : use n_swa + n_ubatch cells for SWA cache #13833)